;;;; -*- mode: lisp; indent-tabs-mode: nil -*-
;;;; sm3.lisp -- implementation of SM3 (GM/T 0004-2012)

(in-package :crypto)
(in-ironclad-readtable)


;;;
;;; Parameters
;;;

;;; Initial chaining value (IV) of SM3: eight 32-bit words loaded into the
;;; digest state by the constructor and by REINITIALIZE-INSTANCE.
(defconst +sm3-initial-state+
  #32@(#x7380166f #x4914b2b9
       #x172442d7 #xda8a0600
       #xa96f30bc #x163138aa
       #xe38dee4d #xb0fb0e4e))


;;;
;;; SM3 rounds
;;;

;;; SM3 permutation P0: X xor (X <<< 9) xor (X <<< 17).
;;; X is bound to a gensym so that a compound argument form is evaluated
;;; exactly once (the previous expansion spliced it in three times).
(defmacro sm3-p0 (x)
  (let ((v (gensym "X")))
    `(let ((,v ,x))
       (declare (type (unsigned-byte 32) ,v))
       (logxor ,v (rol32 ,v 9) (rol32 ,v 17)))))

;;; SM3 permutation P1: X xor (X <<< 15) xor (X <<< 23).
;;; X is bound to a gensym so that a compound argument form is evaluated
;;; exactly once.  SM3-EE passes a non-trivial XOR/rotate expression here,
;;; which the previous expansion recomputed three times.
(defmacro sm3-p1 (x)
  (let ((v (gensym "X")))
    `(let ((,v ,x))
       (declare (type (unsigned-byte 32) ,v))
       (logxor ,v (rol32 ,v 15) (rol32 ,v 23)))))

;;; One message-expansion step.  With the parameters named by their offset
;;; from W_{j-16}, this computes
;;;   W_j = P1(W_{j-16} xor W_{j-9} xor (W_{j-3} <<< 15))
;;;         xor (W_{j-13} <<< 7) xor W_{j-6}.
(defmacro sm3-ee (w0 w7 w13 w3 w10)
  `(logxor ,w10
           (rol32 ,w3 7)
           (sm3-p1 (logxor (rol32 ,w13 15) ,w7 ,w0))))

;;; Bitwise majority of X, Y and Z (FF_j for rounds 16..63): each result bit
;;; is set when at least two of the corresponding input bits are set.
(defmacro sm3-ff (x y z)
  `(logior (logand ,x (logior ,y ,z))
           (logand ,y ,z)))

;;; Bitwise choose (GG_j for rounds 16..63): where a bit of X is set take
;;; the bit from Y, otherwise take it from Z.
(defmacro sm3-gg (x y z)
  `(logxor (logand ,x (logxor ,y ,z))
           ,z))

;;; One compression round for j = 0..15, where both boolean functions are a
;;; plain three-way XOR.  A..H are the working variables as ordered by the
;;; caller, TJ the (pre-rotated) round constant, WI the word W_j and WJ the
;;; word W'_j = W_j xor W_{j+4}.  Only B, D, F and H are assigned; the caller
;;; rotates the variable order between rounds instead of moving values.
(defmacro sm3-r1 (a b c d e f g h tj wi wj)
  (let ((rot-a (gensym "ROT-A"))
        (ss1 (gensym "SS1"))
        (ss2 (gensym "SS2"))
        (tt1 (gensym "TT1"))
        (tt2 (gensym "TT2")))
    `(let* ((,rot-a (rol32 ,a 12))
            (,ss1 (rol32 (mod32+ (mod32+ ,rot-a ,e) ,tj) 7))
            (,ss2 (logxor ,ss1 ,rot-a))
            (,tt1 (mod32+ (mod32+ (mod32+ (logxor ,a ,b ,c) ,d) ,ss2) ,wj))
            (,tt2 (mod32+ (mod32+ (mod32+ (logxor ,e ,f ,g) ,h) ,ss1) ,wi)))
       (declare (type (unsigned-byte 32) ,rot-a ,ss1 ,ss2 ,tt1 ,tt2))
       (setf ,b (rol32 ,b 9))
       (setf ,d ,tt1)
       (setf ,f (rol32 ,f 19))
       (setf ,h (sm3-p0 ,tt2)))))

;;; One compression round for j = 16..63, using SM3-FF (majority) and SM3-GG
;;; (choose) as the boolean functions.  A..H are the working variables as
;;; ordered by the caller, TJ the (pre-rotated) round constant, WI the word
;;; W_j and WJ the word W'_j = W_j xor W_{j+4}.  Only B, D, F and H are
;;; assigned; the caller rotates the variable order between rounds.
(defmacro sm3-r2 (a b c d e f g h tj wi wj)
  (let ((rot-a (gensym "ROT-A"))
        (ss1 (gensym "SS1"))
        (ss2 (gensym "SS2"))
        (tt1 (gensym "TT1"))
        (tt2 (gensym "TT2")))
    `(let* ((,rot-a (rol32 ,a 12))
            (,ss1 (rol32 (mod32+ (mod32+ ,rot-a ,e) ,tj) 7))
            (,ss2 (logxor ,ss1 ,rot-a))
            (,tt1 (mod32+ (mod32+ (mod32+ (sm3-ff ,a ,b ,c) ,d) ,ss2) ,wj))
            (,tt2 (mod32+ (mod32+ (mod32+ (sm3-gg ,e ,f ,g) ,h) ,ss1) ,wi)))
       (declare (type (unsigned-byte 32) ,rot-a ,ss1 ,ss2 ,tt1 ,tt2))
       (setf ,b (rol32 ,b 9))
       (setf ,d ,tt1)
       (setf ,f (rol32 ,f 19))
       (setf ,h (sm3-p0 ,tt2)))))

;;; Compress one 64-byte message block into STATE, in place.
;;;
;;; STATE holds the eight 32-bit chaining words A..H.  The block is read as
;;; sixteen big-endian 32-bit words from DATA starting at byte offset START.
;;; The function is compiled with (safety 0), so no bounds checks are done:
;;; the caller must guarantee START + 64 <= (length DATA).
;;;
;;; Message expansion is computed on the fly in the sixteen variables
;;; W00..W15, used as a circular window.  After round j has consumed W_j
;;; (and W'_j = W_j xor W_{j+4}), that slot is overwritten with W_{j+16} via
;;; SM3-EE.  The last expansion (W67) happens in round 51; rounds 52..63
;;; need no further expansion.
;;;
;;; Instead of shifting A..H after every round, each round macro receives
;;; the working variables in rotated order.  The pattern repeats every four
;;; rounds, so after 64 rounds A..H are back in their original roles.
;;; The immediate constants are the round constants T_j already rotated
;;; left by (j mod 32).
(defun sm3-hash (state data start)
  (declare (type (simple-array (unsigned-byte 32) (8)) state)
           (type (simple-array (unsigned-byte 8) (*)) data)
           (type fixnum start)
           (optimize (speed 3) (space 0) (safety 0) (debug 0)))
  ;; Load the chaining value and the sixteen big-endian message words.
  (let ((a (aref state 0))
        (b (aref state 1))
        (c (aref state 2))
        (d (aref state 3))
        (e (aref state 4))
        (f (aref state 5))
        (g (aref state 6))
        (h (aref state 7))
        (w00 (ub32ref/be data start))
        (w01 (ub32ref/be data (+ start 4)))
        (w02 (ub32ref/be data (+ start 8)))
        (w03 (ub32ref/be data (+ start 12)))
        (w04 (ub32ref/be data (+ start 16)))
        (w05 (ub32ref/be data (+ start 20)))
        (w06 (ub32ref/be data (+ start 24)))
        (w07 (ub32ref/be data (+ start 28)))
        (w08 (ub32ref/be data (+ start 32)))
        (w09 (ub32ref/be data (+ start 36)))
        (w10 (ub32ref/be data (+ start 40)))
        (w11 (ub32ref/be data (+ start 44)))
        (w12 (ub32ref/be data (+ start 48)))
        (w13 (ub32ref/be data (+ start 52)))
        (w14 (ub32ref/be data (+ start 56)))
        (w15 (ub32ref/be data (+ start 60))))
    (declare (type (unsigned-byte 32) a b c d e f g h)
             (type (unsigned-byte 32) w00 w01 w02 w03 w04 w05 w06 w07)
             (type (unsigned-byte 32) w08 w09 w10 w11 w12 w13 w14 w15))
    ;; Rounds 0..15 (SM3-R1: XOR boolean functions); expand W16..W31.
    (sm3-r1 a b c d e f g h #x79cc4519 w00 (logxor w00 w04))
    (setf w00 (sm3-ee w00 w07 w13 w03 w10))
    (sm3-r1 d a b c h e f g #xf3988a32 w01 (logxor w01 w05))
    (setf w01 (sm3-ee w01 w08 w14 w04 w11))
    (sm3-r1 c d a b g h e f #xe7311465 w02 (logxor w02 w06))
    (setf w02 (sm3-ee w02 w09 w15 w05 w12))
    (sm3-r1 b c d a f g h e #xce6228cb w03 (logxor w03 w07))
    (setf w03 (sm3-ee w03 w10 w00 w06 w13))
    (sm3-r1 a b c d e f g h #x9cc45197 w04 (logxor w04 w08))
    (setf w04 (sm3-ee w04 w11 w01 w07 w14))
    (sm3-r1 d a b c h e f g #x3988a32f w05 (logxor w05 w09))
    (setf w05 (sm3-ee w05 w12 w02 w08 w15))
    (sm3-r1 c d a b g h e f #x7311465e w06 (logxor w06 w10))
    (setf w06 (sm3-ee w06 w13 w03 w09 w00))
    (sm3-r1 b c d a f g h e #xe6228cbc w07 (logxor w07 w11))
    (setf w07 (sm3-ee w07 w14 w04 w10 w01))
    (sm3-r1 a b c d e f g h #xcc451979 w08 (logxor w08 w12))
    (setf w08 (sm3-ee w08 w15 w05 w11 w02))
    (sm3-r1 d a b c h e f g #x988a32f3 w09 (logxor w09 w13))
    (setf w09 (sm3-ee w09 w00 w06 w12 w03))
    (sm3-r1 c d a b g h e f #x311465e7 w10 (logxor w10 w14))
    (setf w10 (sm3-ee w10 w01 w07 w13 w04))
    (sm3-r1 b c d a f g h e #x6228cbce w11 (logxor w11 w15))
    (setf w11 (sm3-ee w11 w02 w08 w14 w05))
    (sm3-r1 a b c d e f g h #xc451979c w12 (logxor w12 w00))
    (setf w12 (sm3-ee w12 w03 w09 w15 w06))
    (sm3-r1 d a b c h e f g #x88a32f39 w13 (logxor w13 w01))
    (setf w13 (sm3-ee w13 w04 w10 w00 w07))
    (sm3-r1 c d a b g h e f #x11465e73 w14 (logxor w14 w02))
    (setf w14 (sm3-ee w14 w05 w11 w01 w08))
    (sm3-r1 b c d a f g h e #x228cbce6 w15 (logxor w15 w03))
    (setf w15 (sm3-ee w15 w06 w12 w02 w09))
    ;; Rounds 16..31 (SM3-R2: majority / choose); expand W32..W47.
    (sm3-r2 a b c d e f g h #x9d8a7a87 w00 (logxor w00 w04))
    (setf w00 (sm3-ee w00 w07 w13 w03 w10))
    (sm3-r2 d a b c h e f g #x3b14f50f w01 (logxor w01 w05))
    (setf w01 (sm3-ee w01 w08 w14 w04 w11))
    (sm3-r2 c d a b g h e f #x7629ea1e w02 (logxor w02 w06))
    (setf w02 (sm3-ee w02 w09 w15 w05 w12))
    (sm3-r2 b c d a f g h e #xec53d43c w03 (logxor w03 w07))
    (setf w03 (sm3-ee w03 w10 w00 w06 w13))
    (sm3-r2 a b c d e f g h #xd8a7a879 w04 (logxor w04 w08))
    (setf w04 (sm3-ee w04 w11 w01 w07 w14))
    (sm3-r2 d a b c h e f g #xb14f50f3 w05 (logxor w05 w09))
    (setf w05 (sm3-ee w05 w12 w02 w08 w15))
    (sm3-r2 c d a b g h e f #x629ea1e7 w06 (logxor w06 w10))
    (setf w06 (sm3-ee w06 w13 w03 w09 w00))
    (sm3-r2 b c d a f g h e #xc53d43ce w07 (logxor w07 w11))
    (setf w07 (sm3-ee w07 w14 w04 w10 w01))
    (sm3-r2 a b c d e f g h #x8a7a879d w08 (logxor w08 w12))
    (setf w08 (sm3-ee w08 w15 w05 w11 w02))
    (sm3-r2 d a b c h e f g #x14f50f3b w09 (logxor w09 w13))
    (setf w09 (sm3-ee w09 w00 w06 w12 w03))
    (sm3-r2 c d a b g h e f #x29ea1e76 w10 (logxor w10 w14))
    (setf w10 (sm3-ee w10 w01 w07 w13 w04))
    (sm3-r2 b c d a f g h e #x53d43cec w11 (logxor w11 w15))
    (setf w11 (sm3-ee w11 w02 w08 w14 w05))
    (sm3-r2 a b c d e f g h #xa7a879d8 w12 (logxor w12 w00))
    (setf w12 (sm3-ee w12 w03 w09 w15 w06))
    (sm3-r2 d a b c h e f g #x4f50f3b1 w13 (logxor w13 w01))
    (setf w13 (sm3-ee w13 w04 w10 w00 w07))
    (sm3-r2 c d a b g h e f #x9ea1e762 w14 (logxor w14 w02))
    (setf w14 (sm3-ee w14 w05 w11 w01 w08))
    (sm3-r2 b c d a f g h e #x3d43cec5 w15 (logxor w15 w03))
    (setf w15 (sm3-ee w15 w06 w12 w02 w09))
    ;; Rounds 32..47; expand W48..W63.
    (sm3-r2 a b c d e f g h #x7a879d8a w00 (logxor w00 w04))
    (setf w00 (sm3-ee w00 w07 w13 w03 w10))
    (sm3-r2 d a b c h e f g #xf50f3b14 w01 (logxor w01 w05))
    (setf w01 (sm3-ee w01 w08 w14 w04 w11))
    (sm3-r2 c d a b g h e f #xea1e7629 w02 (logxor w02 w06))
    (setf w02 (sm3-ee w02 w09 w15 w05 w12))
    (sm3-r2 b c d a f g h e #xd43cec53 w03 (logxor w03 w07))
    (setf w03 (sm3-ee w03 w10 w00 w06 w13))
    (sm3-r2 a b c d e f g h #xa879d8a7 w04 (logxor w04 w08))
    (setf w04 (sm3-ee w04 w11 w01 w07 w14))
    (sm3-r2 d a b c h e f g #x50f3b14f w05 (logxor w05 w09))
    (setf w05 (sm3-ee w05 w12 w02 w08 w15))
    (sm3-r2 c d a b g h e f #xa1e7629e w06 (logxor w06 w10))
    (setf w06 (sm3-ee w06 w13 w03 w09 w00))
    (sm3-r2 b c d a f g h e #x43cec53d w07 (logxor w07 w11))
    (setf w07 (sm3-ee w07 w14 w04 w10 w01))
    (sm3-r2 a b c d e f g h #x879d8a7a w08 (logxor w08 w12))
    (setf w08 (sm3-ee w08 w15 w05 w11 w02))
    (sm3-r2 d a b c h e f g #x0f3b14f5 w09 (logxor w09 w13))
    (setf w09 (sm3-ee w09 w00 w06 w12 w03))
    (sm3-r2 c d a b g h e f #x1e7629ea w10 (logxor w10 w14))
    (setf w10 (sm3-ee w10 w01 w07 w13 w04))
    (sm3-r2 b c d a f g h e #x3cec53d4 w11 (logxor w11 w15))
    (setf w11 (sm3-ee w11 w02 w08 w14 w05))
    (sm3-r2 a b c d e f g h #x79d8a7a8 w12 (logxor w12 w00))
    (setf w12 (sm3-ee w12 w03 w09 w15 w06))
    (sm3-r2 d a b c h e f g #xf3b14f50 w13 (logxor w13 w01))
    (setf w13 (sm3-ee w13 w04 w10 w00 w07))
    (sm3-r2 c d a b g h e f #xe7629ea1 w14 (logxor w14 w02))
    (setf w14 (sm3-ee w14 w05 w11 w01 w08))
    (sm3-r2 b c d a f g h e #xcec53d43 w15 (logxor w15 w03))
    (setf w15 (sm3-ee w15 w06 w12 w02 w09))
    ;; Rounds 48..51; the final four expansions produce W64..W67.
    (sm3-r2 a b c d e f g h #x9d8a7a87 w00 (logxor w00 w04))
    (setf w00 (sm3-ee w00 w07 w13 w03 w10))
    (sm3-r2 d a b c h e f g #x3b14f50f w01 (logxor w01 w05))
    (setf w01 (sm3-ee w01 w08 w14 w04 w11))
    (sm3-r2 c d a b g h e f #x7629ea1e w02 (logxor w02 w06))
    (setf w02 (sm3-ee w02 w09 w15 w05 w12))
    (sm3-r2 b c d a f g h e #xec53d43c w03 (logxor w03 w07))
    (setf w03 (sm3-ee w03 w10 w00 w06 w13))
    ;; Rounds 52..63: W52..W67 are all in the window; no more expansion.
    (sm3-r2 a b c d e f g h #xd8a7a879 w04 (logxor w04 w08))
    (sm3-r2 d a b c h e f g #xb14f50f3 w05 (logxor w05 w09))
    (sm3-r2 c d a b g h e f #x629ea1e7 w06 (logxor w06 w10))
    (sm3-r2 b c d a f g h e #xc53d43ce w07 (logxor w07 w11))
    (sm3-r2 a b c d e f g h #x8a7a879d w08 (logxor w08 w12))
    (sm3-r2 d a b c h e f g #x14f50f3b w09 (logxor w09 w13))
    (sm3-r2 c d a b g h e f #x29ea1e76 w10 (logxor w10 w14))
    (sm3-r2 b c d a f g h e #x53d43cec w11 (logxor w11 w15))
    (sm3-r2 a b c d e f g h #xa7a879d8 w12 (logxor w12 w00))
    (sm3-r2 d a b c h e f g #x4f50f3b1 w13 (logxor w13 w01))
    (sm3-r2 c d a b g h e f #x9ea1e762 w14 (logxor w14 w02))
    (sm3-r2 b c d a f g h e #x3d43cec5 w15 (logxor w15 w03))
    ;; Feed-forward: the new chaining value is (A..H) xor the old one.
    (setf (aref state 0) (logxor (aref state 0) a)
          (aref state 1) (logxor (aref state 1) b)
          (aref state 2) (logxor (aref state 2) c)
          (aref state 3) (logxor (aref state 3) d)
          (aref state 4) (logxor (aref state 4) e)
          (aref state 5) (logxor (aref state 5) f)
          (aref state 6) (logxor (aref state 6) g)
          (aref state 7) (logxor (aref state 7) h))))


;;;
;;; Digest structures and functions
;;;

;;; Incremental SM3 hashing state.
(defstruct (sm3
            (:constructor %make-sm3-digest ())
            (:copier nil))
  ;; Eight 32-bit chaining words, initialized from +SM3-INITIAL-STATE+.
  (state (copy-seq +sm3-initial-state+)
   :type (simple-array (unsigned-byte 32) (8)))
  ;; Total number of message bytes absorbed so far.
  (count 0 :type (unsigned-byte 64))
  ;; Partial 64-byte block awaiting compression.
  (buffer (make-array 64 :element-type '(unsigned-byte 8))
   :type (simple-array (unsigned-byte 8) (64)))
  ;; Number of valid bytes currently held in BUFFER.
  (buffer-index 0 :type (integer 0 64)))

;;; Reset STATE to the beginning of a fresh SM3 computation and return it.
;;; INITARGS are accepted for protocol compatibility and ignored.
(defmethod reinitialize-instance ((state sm3) &rest initargs)
  (declare (ignore initargs))
  (setf (sm3-count state) 0)
  (setf (sm3-buffer-index state) 0)
  (replace (sm3-state state) +sm3-initial-state+)
  state)

;;; Copy the full hashing state of STATE (chaining words, buffer, byte count
;;; and buffer fill level) into COPY and return it.  A fresh SM3 object is
;;; allocated when COPY is NIL; otherwise COPY must be an SM3 object.
(defmethod copy-digest ((state sm3) &optional copy)
  (check-type copy (or null sm3))
  (let ((target (or copy (%make-sm3-digest))))
    (declare (type sm3 target))
    (setf (sm3-count target) (sm3-count state)
          (sm3-buffer-index target) (sm3-buffer-index state))
    (replace (sm3-buffer target) (sm3-buffer state))
    (replace (sm3-state target) (sm3-state state))
    target))

;;; Absorb the bytes SEQUENCE[START, END) into the SM3 state.  Input first
;;; completes any partially filled buffer, then whole 64-byte blocks are
;;; compressed directly from SEQUENCE, and a remaining tail (< 64 bytes) is
;;; kept in the buffer for the next call.
(define-digest-updater sm3
  (let* ((s (sm3-state state))
         (buffer (sm3-buffer state))
         (count (sm3-count state))
         (buffer-index (sm3-buffer-index state))
         (remaining (- end start)))
    (declare (type (simple-array (unsigned-byte 32) (8)) s)
             (type (simple-array (unsigned-byte 8) (64)) buffer)
             (type (unsigned-byte 64) count)
             (type (integer 0 64) buffer-index)
             (type fixnum remaining))
    ;; 1. Top up a partially filled buffer; compress it once it is full.
    (unless (zerop buffer-index)
      (let ((take (min remaining (- 64 buffer-index))))
        (declare (type (integer 0 64) take))
        (replace buffer sequence
                 :start1 buffer-index :start2 start :end2 (+ start take))
        (incf count take)
        (incf buffer-index take)
        (incf start take)
        (decf remaining take)
        (when (= buffer-index 64)
          (sm3-hash s buffer 0)
          (setf buffer-index 0))))
    ;; 2. Compress whole blocks straight from the input, without copying.
    (loop while (>= remaining 64)
          do (sm3-hash s sequence start)
             (incf count 64)
             (incf start 64)
             (decf remaining 64))
    ;; 3. Stash the tail.  If any input is left here, BUFFER-INDEX is 0:
    ;;    step 1 either did not run or filled and flushed the buffer.
    (when (plusp remaining)
      (replace buffer sequence :start2 start :end2 end)
      (incf count remaining)
      (setf buffer-index remaining))
    (setf (sm3-count state) count
          (sm3-buffer-index state) buffer-index)
    (values)))

;;; Pad the buffered tail and compress the final block(s), then write the
;;; 32-byte digest into DIGEST at DIGEST-START.  Padding is a single 0x80
;;; byte, zeros up to byte 56, and the message length in bits as a 64-bit
;;; big-endian integer.  The chaining words of STATE are updated in place
;;; by SM3-HASH.
(define-digest-finalizer (sm3 32)
  (let* ((s (sm3-state state))
         (buffer (sm3-buffer state))
         (pos (sm3-buffer-index state))
         (bit-count (* 8 (sm3-count state))))
    ;; Append the mandatory terminating 1 bit.
    (setf (aref buffer pos) #x80)
    (incf pos)
    ;; No room left for the 8-byte length: flush a zero-filled block first.
    (when (> pos 56)
      (fill buffer 0 :start pos)
      (sm3-hash s buffer 0)
      (setf pos 0))
    (fill buffer 0 :start pos :end 56)
    (setf (ub64ref/be buffer 56) bit-count)
    (sm3-hash s buffer 0)
    ;; Serialize the eight chaining words big-endian.
    (let ((output (make-array 32 :element-type '(unsigned-byte 8))))
      (loop for i from 0 below 8
            do (setf (ub32ref/be output (* 4 i)) (aref s i)))
      (replace digest output :start1 digest-start)
      digest)))

;;; Register SM3 with the generic digest machinery: 32-byte (256-bit)
;;; output and 64-byte (512-bit) input blocks.
(defdigest sm3 :digest-length 32 :block-length 64)
